{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "initial_id",
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    ""
   ]
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": [
    "# 1、create_sql_query_chain的使用\n",
    "\n",
    "举例1："
   ],
   "id": "1765edddb05fa68e"
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-08-31T10:04:15.345934Z",
     "start_time": "2025-08-31T10:04:15.305709Z"
    }
   },
   "cell_type": "code",
   "source": [
    "# pip install -U langchain langchain-community langchain-openai\n",
    "from langchain_openai import ChatOpenAI\n",
    "from langchain.chains import create_sql_query_chain\n",
    "from langchain_community.utilities import SQLDatabase\n",
    "\n",
    "# 测试连接本地的mysql数据库\n",
    "db_user = \"root\"\n",
    "db_password = \"abc123\"\n",
    "db_host = \"localhost\" #或 127.0.0.1\n",
    "db_port = \"3306\"\n",
    "db_database = \"atguigudb\"\n",
    "# mysql+pymysql://用户名:密码@ip地址:端口号/数据库名\n",
    "db = SQLDatabase.from_uri(f\"mysql+pymysql://{db_user}:{db_password}@{db_host}:{db_port}/{db_database}\")\n",
    "\n",
    "print(\"操作的是哪种数据库：\",db.dialect)\n",
    "print(\"获取数据库中的表：\",db.get_usable_table_names())\n",
    "\n",
    "#执行查询操作\n",
    "res = db.run(\"SELECT COUNT(*) FROM employees\")\n",
    "print(res)"
   ],
   "id": "fed4ce3b241a3bf6",
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "操作的是哪种数据库： mysql\n",
      "获取数据库中的表： ['countries', 'departments', 'employees', 'job_grades', 'job_history', 'jobs', 'locations', 'order', 'regions']\n",
      "[(107,)]\n"
     ]
    }
   ],
   "execution_count": 3
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "举例2：chain的使用",
   "id": "cf8c8a65def2986f"
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-08-31T10:09:44.371738Z",
     "start_time": "2025-08-31T10:09:42.914729Z"
    }
   },
   "cell_type": "code",
   "source": [
    "# pip install -U langchain langchain-community langchain-openai\n",
    "from langchain_openai import ChatOpenAI\n",
    "from langchain.chains import create_sql_query_chain\n",
    "from langchain_community.utilities import SQLDatabase\n",
    "\n",
    "#1、获取mysql数据库的连接\n",
    "# 测试连接本地的mysql数据库\n",
    "db_user = \"root\"\n",
    "db_password = \"abc123\"\n",
    "db_host = \"localhost\" #或 127.0.0.1\n",
    "db_port = \"3306\"\n",
    "db_database = \"atguigudb\"\n",
    "# mysql+pymysql://用户名:密码@ip地址:端口号/数据库名\n",
    "db = SQLDatabase.from_uri(f\"mysql+pymysql://{db_user}:{db_password}@{db_host}:{db_port}/{db_database}\")\n",
    "\n",
    "# 2、获取大语言模型\n",
    "import os\n",
    "import dotenv\n",
    "from langchain_openai import ChatOpenAI\n",
    "\n",
    "dotenv.load_dotenv()\n",
    "\n",
    "os.environ['OPENAI_API_KEY'] = os.getenv(\"OPENAI_API_KEY1\")\n",
    "os.environ['OPENAI_BASE_URL'] = os.getenv(\"OPENAI_BASE_URL\")\n",
    "chat_model = ChatOpenAI(model=\"gpt-4o-mini\")\n",
    "\n",
    "# 3、创建create_sql_query_chain的实例\n",
    "chain = create_sql_query_chain(chat_model, db)\n",
    "# response = chain.invoke({\"question\": \"数据表employees中一共有多少个员工？\",\n",
    "#                          \"table_names_to_use\":[\"employees\"]})\n",
    "# print(response)\n",
    "\n",
    "response = chain.invoke({\"question\": \"数据表employees中薪资最高的员工信息\",\n",
    "                         \"table_names_to_use\":[\"employees\"]})\n",
    "print(response)"
   ],
   "id": "484e8743e6497a50",
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "D:\\developTools\\miniconda3\\envs\\pyth310\\lib\\site-packages\\langchain_community\\utilities\\sql_database.py:338: SAWarning: Cannot correctly sort tables; there are unresolvable cycles between tables \"departments, employees\", which is usually caused by mutually dependent foreign key constraints.  Foreign key constraints involving these tables will not be considered; this warning may raise an error in a future release.\n",
      "  metadata_table_names = [tbl.name for tbl in self._metadata.sorted_tables]\n",
      "D:\\developTools\\miniconda3\\envs\\pyth310\\lib\\site-packages\\langchain_community\\utilities\\sql_database.py:350: SAWarning: Cannot correctly sort tables; there are unresolvable cycles between tables \"departments, employees\", which is usually caused by mutually dependent foreign key constraints.  Foreign key constraints involving these tables will not be considered; this warning may raise an error in a future release.\n",
      "  for tbl in self._metadata.sorted_tables\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Question: 数据表employees中薪资最高的员工信息\n",
      "SQLQuery: SELECT `employee_id`, `first_name`, `last_name`, `email`, `salary` FROM `employees` ORDER BY `salary` DESC LIMIT 1\n"
     ]
    }
   ],
   "execution_count": 8
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "# 2、create_stuff_documents_chain的使用",
   "id": "e35ead3b5aae2137"
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2025-08-31T10:17:17.835698Z",
     "start_time": "2025-08-31T10:17:16.382335Z"
    }
   },
   "cell_type": "code",
   "source": [
    "from langchain.chains.combine_documents import create_stuff_documents_chain\n",
    "from langchain_core.prompts import PromptTemplate\n",
    "from langchain_openai import ChatOpenAI\n",
    "from langchain_core.documents import Document\n",
    "\n",
    "# 定义提示词模板\n",
    "prompt = PromptTemplate.from_template(\"\"\"\n",
    "基于文档{docs}中说的情况，香蕉是什么颜色的？\n",
    "\"\"\")\n",
    "\n",
    "# 创建链\n",
    "llm = ChatOpenAI(model=\"gpt-4o-mini\")\n",
    "chain = create_stuff_documents_chain(llm, prompt, document_variable_name=\"docs\")\n",
    "\n",
    "# 文档输入\n",
    "docs123 = [\n",
    "    Document(\n",
    "        page_content=\"苹果，学名Malus pumila Mill.，别称西洋苹果、柰，属于蔷薇科苹果属的植物。苹果是全球最广泛种植和销售的水果之一，具有悠久的栽培历史和广泛的分布范围。苹果的原始种群主要起源于中亚的天山山脉附近，尤其是现代哈萨克斯坦的阿拉木图地区，提供了所有现代苹果品种的基因库。苹果通过早期的贸易路线，如丝绸之路，从中亚向外扩散到全球各地。\"\n",
    "    ),\n",
    "    Document(\n",
    "        page_content=\"香蕉是白色的水果，主要产自热带地区。\"\n",
    "\n",
    "    ),\n",
    "    Document(\n",
    "        page_content=\"蓝莓是蓝色的浆果，含有抗氧化物质。\"\n",
    "\n",
    "    )\n",
    "]\n",
    "# 执行摘要\n",
    "chain.invoke({\"docs\": docs123})"
   ],
   "id": "3987a26f3dbc406f",
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'根据你提供的信息，香蕉被描述为“白色的水果”。但实际上，香蕉的外表通常是黄色的，尤其是在成熟时。可能这里存在一个误解或错误的描述，通常我们所说的香蕉是黄色的，而未熟的香蕉是绿色的，过熟的香蕉则可能变为棕色。'"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "execution_count": 14
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
